{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers 설치 방법\n",
    "! pip install transformers datasets evaluate accelerate\n",
    "# 마지막 릴리스 대신 소스에서 설치하려면, 위 명령을 주석으로 바꾸고 아래 명령을 해제하세요.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 이미지 캡셔닝[[image-captioning]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이미지 캡셔닝(Image captioning)은 주어진 이미지에 대한 캡션을 예측하는 작업입니다. \n",
    "이미지 캡셔닝은 시각 장애인이 다양한 상황을 탐색하는 데 도움을 줄 수 있도록 시각 장애인을 보조하는 등 실생활에서 흔히 활용됩니다. \n",
    "따라서 이미지 캡셔닝은 이미지를 설명함으로써 사람들의 콘텐츠 접근성을 개선하는 데 도움이 됩니다.\n",
    "\n",
    "이 가이드에서는 소개할 내용은 아래와 같습니다:\n",
    "\n",
    "* 이미지 캡셔닝 모델을 파인튜닝합니다.\n",
    "* 파인튜닝된 모델을 추론에 사용합니다.\n",
    "\n",
    "시작하기 전에 필요한 모든 라이브러리가 설치되어 있는지 확인하세요:\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate -q\n",
    "pip install jiwer -q\n",
    "```\n",
    "\n",
    "Hugging Face 계정에 로그인하면 모델을 업로드하고 커뮤니티에 공유할 수 있습니다. \n",
    "토큰을 입력하여 로그인하세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 포켓몬 BLIP 캡션 데이터세트 가져오기[[load-the-pokmon-blip-captions-dataset]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{이미지-캡션} 쌍으로 구성된 데이터세트를 가져오려면 🤗 Dataset 라이브러리를 사용합니다. \n",
    "PyTorch에서 자신만의 이미지 캡션 데이터세트를 만들려면 [이 노트북](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb)을 참조하세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"lambdalabs/pokemon-blip-captions\")\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['image', 'text'],\n",
    "        num_rows: 833\n",
    "    })\n",
    "})\n",
    "```\n",
    "\n",
    "이 데이터세트는 `image`와 `text`라는 두 특성을 가지고 있습니다.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "많은 이미지 캡션 데이터세트에는 이미지당 여러 개의 캡션이 포함되어 있습니다. \n",
    "이러한 경우, 일반적으로 학습 중에 사용 가능한 캡션 중에서 무작위로 샘플을 추출합니다. \n",
    "\n",
    "</Tip>\n",
    "\n",
    "`train_test_split` 메소드를 사용하여 데이터세트의 학습 분할을 학습 및 테스트 세트로 나눕니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = ds[\"train\"].train_test_split(test_size=0.1)\n",
    "train_ds = ds[\"train\"]\n",
    "test_ds = ds[\"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "학습 세트의 샘플 몇 개를 시각화해 봅시다.\n",
    "Let's visualize a couple of samples from the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textwrap import wrap\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def plot_images(images, captions):\n",
    "    plt.figure(figsize=(20, 20))\n",
    "    for i in range(len(images)):\n",
    "        ax = plt.subplot(1, len(images), i + 1)\n",
    "        caption = captions[i]\n",
    "        caption = \"\\n\".join(wrap(caption, 12))\n",
    "        plt.title(caption)\n",
    "        plt.imshow(images[i])\n",
    "        plt.axis(\"off\")\n",
    "\n",
    "\n",
    "sample_images_to_visualize = [np.array(train_ds[i][\"image\"]) for i in range(5)]\n",
    "sample_captions = [train_ds[i][\"text\"] for i in range(5)]\n",
    "plot_images(sample_images_to_visualize, sample_captions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png\" alt=\"Sample training images\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 데이터세트 전처리[[preprocess-the-dataset]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "데이터세트에는 이미지와 텍스트라는 두 가지 양식이 있기 때문에, 전처리 파이프라인에서 이미지와 캡션을 모두 전처리합니다.\n",
    "\n",
    "전처리 작업을 위해, 파인튜닝하려는 모델에 연결된 프로세서 클래스를 가져옵니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "checkpoint = \"microsoft/git-base\"\n",
    "processor = AutoProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "프로세서는 내부적으로 크기 조정 및 픽셀 크기 조정을 포함한 이미지 전처리를 수행하고 캡션을 토큰화합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transforms(example_batch):\n",
    "    images = [x for x in example_batch[\"image\"]]\n",
    "    captions = [x for x in example_batch[\"text\"]]\n",
    "    inputs = processor(images=images, text=captions, padding=\"max_length\")\n",
    "    inputs.update({\"labels\": inputs[\"input_ids\"]})\n",
    "    return inputs\n",
    "\n",
    "\n",
    "train_ds.set_transform(transforms)\n",
    "test_ds.set_transform(transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "데이터세트가 준비되었으니 이제 파인튜닝을 위해 모델을 설정할 수 있습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 기본 모델 가져오기[[load-a-base-model]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[\"microsoft/git-base\"](https://huggingface.co/microsoft/git-base)를 [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) 객체로 가져옵니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 평가[[evaluate]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이미지 캡션 모델은 일반적으로 [Rouge 점수](https://huggingface.co/spaces/evaluate-metric/rouge) 또는 [단어 오류율(Word Error Rate)](https://huggingface.co/spaces/evaluate-metric/wer)로 평가합니다. \n",
    "이 가이드에서는 단어 오류율(WER)을 사용합니다. \n",
    "\n",
    "이를 위해 🤗 Evaluate 라이브러리를 사용합니다. \n",
    "WER의 잠재적 제한 사항 및 기타 문제점은 [이 가이드](https://huggingface.co/spaces/evaluate-metric/wer)를 참조하세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load\n",
    "import torch\n",
    "\n",
    "wer = load(\"wer\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predicted = logits.argmax(-1)\n",
    "    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)\n",
    "    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)\n",
    "    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)\n",
    "    return {\"wer_score\": wer_score}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 학습![[train!]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 모델 파인튜닝을 시작할 준비가 되었습니다. 이를 위해 🤗 [Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer)를 사용합니다. \n",
    "\n",
    "먼저, [TrainingArguments](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.TrainingArguments)를 사용하여 학습 인수를 정의합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "model_name = checkpoint.split(\"/\")[1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-pokemon\",\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=50,\n",
    "    fp16=True,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    gradient_accumulation_steps=2,\n",
    "    save_total_limit=3,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=50,\n",
    "    logging_steps=50,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    "    label_names=[\"labels\"],\n",
    "    load_best_model_at_end=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "학습 인수를 데이터세트, 모델과 함께 🤗 Trainer에 전달합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=test_ds,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "학습을 시작하려면 [Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer) 객체에서 [train()](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer.train)을 호출하기만 하면 됩니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "학습이 진행되면서 학습 손실이 원활하게 감소하는 것을 볼 수 있습니다.\n",
    "\n",
    "학습이 완료되면 모든 사람이 모델을 사용할 수 있도록 [push_to_hub()](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer.push_to_hub) 메소드를 사용하여 모델을 허브에 공유하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 추론[[inference]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`test_ds`에서 샘플 이미지를 가져와 모델을 테스트합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "\n",
    "url = \"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png\" alt=\"Test image\"/>\n",
    "</div>\n",
    "    \n",
    "모델에 사용할 이미지를 준비합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "inputs = processor(images=image, return_tensors=\"pt\").to(device)\n",
    "pixel_values = inputs.pixel_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`generate`를 호출하고 예측을 디코딩합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_ids = model.generate(pixel_values=pixel_values, max_length=50)\n",
    "generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "print(generated_caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "a drawing of a pink and blue pokemon\n",
    "```\n",
    "\n",
    "파인튜닝된 모델이 꽤 괜찮은 캡션을 생성한 것 같습니다!"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
